GAN model for image generation¶

Requirements:¶

  • A computer with a CUDA-capable graphics card
  • The Anaconda distribution, to run the Jupyter notebook in an environment with CUDA support
  • Python 3
  • Restart the kernel after a run to free the VRAM it occupied
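The CUDA requirement can be verified up front with a short check (a minimal sketch; note that `torch.cuda.empty_cache()` only releases PyTorch's cached VRAM, so a kernel restart remains the reliable way to clear everything):

```python
import torch

# Quick environment check before training: verify CUDA is visible to PyTorch
# and fall back to the CPU if no compatible GPU is found.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    # Releases cached VRAM between runs; a kernel restart is still
    # the only reliable way to free all GPU memory.
    torch.cuda.empty_cache()
```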

References¶

UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS - Alec Radford & Luke Metz https://arxiv.org/pdf/1511.06434.pdf

In [1]:
# Package installation

!pip install opencv-python
# cv2 - image manipulation
!pip install numpy
# array manipulation and other numerical utilities
!pip install matplotlib 
# plotting
!pip install torch
# PyTorch - the open-source machine learning framework used to build the GAN model
!pip install torchvision
# popular datasets and image transforms used for testing

Importing libraries and loading data¶

In [2]:
# Image manipulation
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Building neural networks
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
import torchvision.utils as vutils
import torchvision
from torch.autograd import Variable
# Displaying results
from tqdm import tqdm
import matplotlib.animation as animation
from IPython.display import HTML

%matplotlib inline

Dataset¶

Training the network and generating images is done on three themes:

  1. Pokemon monsters
  2. The TV series Lud, Zbunjen i Normalan
  3. Klix news (article images)
  4. Other - simple datasets for easier testing
  • A well-trained GAN model needs a large number of images fixed to a given resolution; after synthesis, the generated images can be rescaled to any desired size
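Fixing all images to one resolution can be sketched with a nearest-neighbour resize; this is a library-independent illustration in plain NumPy (the notebook itself relies on cv2 and torchvision for the real preprocessing, and a synthetic array stands in for a dataset file):

```python
import numpy as np

# Nearest-neighbour resize to a fixed 256x256 resolution: pick, for each
# output pixel, the nearest source row and column by integer scaling.
def resize_nearest(img, size=256):
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows][:, cols]

# Stand-in for cv2.imread(path): a random 300x400 RGB image
img = (np.random.rand(300, 400, 3) * 255).astype(np.uint8)
fixed = resize_nearest(img)
print(fixed.shape)  # (256, 256, 3)
```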
In [96]:
# Dataset 
pokemon_data_directory = './datasets/pokemon/'
print('Number of images in the pokemon dataset: ' + str(len(os.listdir(pokemon_data_directory))))

pokemon_card = cv2.imread(pokemon_data_directory + os.listdir(pokemon_data_directory)[0])
print("Image shape in the pokemon dataset: " + str(pokemon_card.shape) + ". Resolution: 256x256, RGB image\n")

# plt.imshow(pokemon_card)

vijesti_data_directory = './datasets/vijesti/'
print('Number of images in the vijesti dataset: ' + str(len(os.listdir(vijesti_data_directory))))

vijest_article = cv2.imread(vijesti_data_directory + os.listdir(vijesti_data_directory)[0])
print("Image shape in the vijesti dataset: " + str(vijest_article.shape) + ". Resolution: 256x256, RGB image\n")

lzn_data_directory = './datasets/lud_zbunjen_normalan/'
print('Number of images in the lzn dataset: ' + str(len(os.listdir(lzn_data_directory))))

lzn_image = cv2.imread(lzn_data_directory + os.listdir(lzn_data_directory)[0])
print("Image shape in the lzn dataset: " + str(lzn_image.shape) + ". Resolution: 256x256, RGB image\n")

red_dress_directory = './datasets/red_dress/'
print('Number of images in the red dress dataset: ' + str(len(os.listdir(red_dress_directory))) + '\n')

simpsons_data_directory = './datasets/simpsons/'
print('Number of images in the simpsons dataset: ' + str(len(os.listdir(simpsons_data_directory))) + '\n')

# Select the dataset used for training
training_dataset = simpsons_data_directory
Number of images in the pokemon dataset: 819
Image shape in the pokemon dataset: (256, 256, 3). Resolution: 256x256, RGB image

Number of images in the vijesti dataset: 1355
Image shape in the vijesti dataset: (256, 256, 3). Resolution: 256x256, RGB image

Number of images in the lzn dataset: 860
Image shape in the lzn dataset: (256, 256, 3). Resolution: 256x256, RGB image

Number of images in the red dress dataset: 800

Number of images in the simpsons dataset: 9877

In [97]:
# Display a sample grid from a dataset
def showDatasetSampleGrid(data_directory, num_of_images, grid_size, title):
    # Grid dimensions
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8))
    # Load random images from the first num_of_images files
    file_names = os.listdir(data_directory)
    for ax in axes.flatten():
        rand_index = random.randrange(min(num_of_images, len(file_names)))
        dataset_image = file_names[rand_index]
        img = cv2.imread(data_directory + dataset_image, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(dataset_image.split('.')[0])

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Pokemon dataset
# showDatasetSampleGrid(pokemon_data_directory, 800, 5, 'Pokemon dataset')
# showDatasetSampleGrid(vijesti_data_directory, 1300, 5, 'Vijesti dataset')
showDatasetSampleGrid(lzn_data_directory, 800, 5, 'Lud zbunjen normalan dataset')
# showDatasetSampleGrid(red_dress_directory, 800, 5, 'Crvene haljine dataset')
# showDatasetSampleGrid(simpsons_data_directory, 8000, 5, 'Simpsons dataset')

Weight initialization¶

  • Weights are the parameters of a neural network that transform the input data inside the hidden layers. Each input has an associated weight, which carries the importance of the corresponding feature into the final output. Here the initial weights are drawn from a normal distribution with mean 0 and standard deviation 0.02

Reference: https://proceedings.neurips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf

In [200]:
def initialize_weights_v2(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm scale starts around 1; its bias has no bias=False variant here
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
    
def initialize_weights(model):
    for m in model.modules():
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)
        
seed = random.randint(1, 10000)
random.seed(seed)
torch.manual_seed(seed)
Out[200]:
<torch._C.Generator at 0x20d56ca1050>

GAN model classes¶

Input¶

  • in_channels: number of input channels for the layer
  • out_channels: number of output channels for the layer
  • kernel_size: kernel dimension
  • stride: kernel stride (default: 2), used instead of pooling layers
  • padding: kernel padding (default: 1)
  • bn: whether batch normalization is applied (default: True)
  • leaky_relu is applied in place to save memory (a CUDA memory constraint)

in_channels (input features) and out_channels determine the size of the layer's weight tensor
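As a quick check of that claim, the weight tensor of a transposed convolution is shaped directly by these two arguments:

```python
import torch.nn as nn

# ConvTranspose2d stores its weights as (in_channels, out_channels, kH, kW),
# so in_channels and out_channels directly determine the parameter count.
layer = nn.ConvTranspose2d(in_channels=100, out_channels=64, kernel_size=4, bias=False)
print(tuple(layer.weight.shape))  # (100, 64, 4, 4)
```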

In [201]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_channels, features_g, resize=False):
        super(Generator, self).__init__()
        self.resize = resize
        if not self.resize:
            self.gen = nn.Sequential(
                self.layer(in_channels=z_dim, out_channels=features_g*64, kernel_size=4, stride=1, padding=0, bn=True), # 4x4
                self.layer(in_channels=features_g*64, out_channels=features_g*32, kernel_size=4, stride=2, padding=1, bn=True), # 8x8
                self.layer(in_channels=features_g*32, out_channels=features_g*16, kernel_size=4, stride=2, padding=1, bn=True), # 16x16
                self.layer(in_channels=features_g*16, out_channels=features_g*8, kernel_size=4, stride=2, padding=1, bn=True), # 32x32
                self.layer(in_channels=features_g*8, out_channels=features_g*4, kernel_size=4, stride=2, padding=1, bn=True), # 64x64
                self.layer(in_channels=features_g*4, out_channels=features_g*2, kernel_size=4, stride=2, padding=1, bn=True), # 128x128
                nn.ConvTranspose2d(in_channels=features_g*2, out_channels=img_channels, kernel_size=4, stride=2, padding=1), # 256x256
                nn.Tanh() # maps (normalizes) outputs to [-1, 1], the range of the real images
            )
        else:
            self.gen = nn.Sequential(
                self.layer(in_channels=z_dim, out_channels=features_g*8, kernel_size=4, stride=1, padding=0, bn=True), # 4x4
                self.layer(in_channels=features_g*8, out_channels=features_g*4, kernel_size=4, stride=2, padding=1, bn=True), # 8x8
                self.layer(in_channels=features_g*4, out_channels=features_g*2, kernel_size=4, stride=2, padding=1, bn=True), # 16x16
                self.layer(in_channels=features_g*2, out_channels=features_g, kernel_size=4, stride=2, padding=1, bn=True), # 32x32
                # self.layer(in_channels=features_g, out_channels=features_g, kernel_size=4, stride=2, padding=1, bn=True),
                nn.ConvTranspose2d(in_channels=features_g, out_channels=img_channels, kernel_size=4, stride=2, padding=1, bias=False), # 64x64
                nn.Tanh()
            )
    
    def layer(self, in_channels, out_channels, kernel_size, stride, padding, bn=True):
        if bn:
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(True)
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
                nn.ReLU(True)
            )
    
    def forward(self, input):
        return self.gen(input)

class Discriminator(nn.Module):
    def __init__(self, img_channels, features_d, resize=False):
        super(Discriminator, self).__init__()
        self.resize = resize
        if not self.resize:
            self.disc = nn.Sequential(
                # input: 3x256x256
                self.layer(in_channels=img_channels, out_channels=features_d, kernel_size=4, stride=2, padding=1, bn=False), # 128x128
                self.layer(in_channels=features_d, out_channels=features_d*2, kernel_size=4, stride=2, padding=1, bn=True), # 64x64
                self.layer(in_channels=features_d*2, out_channels=features_d*4, kernel_size=4, stride=2, padding=1, bn=True), # 32x32
                self.layer(in_channels=features_d*4, out_channels=features_d*8, kernel_size=4, stride=2, padding=1, bn=True), # 16x16
                self.layer(in_channels=features_d*8, out_channels=features_d*16, kernel_size=4, stride=2, padding=1, bn=True), # 8x8
                self.layer(in_channels=features_d*16, out_channels=features_d*32, kernel_size=4, stride=2, padding=1, bn=True), # 4x4
                nn.Conv2d(in_channels=features_d*32, out_channels=1, kernel_size=4, stride=2, padding=0), # 1x1
                # nn.Sigmoid()  # omitted: BCEWithLogitsLoss applies the sigmoid itself
            ) 
        else:
            self.disc = nn.Sequential(
                # input: 3x64x64
                self.layer(in_channels=img_channels, out_channels=features_d, kernel_size=4, stride=2, padding=1, bn=False), # 32x32
                self.layer(in_channels=features_d, out_channels=features_d*2, kernel_size=4, stride=2, padding=1, bn=True), # 16x16
                self.layer(in_channels=features_d*2, out_channels=features_d*4, kernel_size=4, stride=2, padding=1, bn=True), # 8x8
                self.layer(in_channels=features_d*4, out_channels=features_d*8, kernel_size=4, stride=2, padding=1, bn=True), # 4x4
                nn.Conv2d(in_channels=features_d*8, out_channels=1, kernel_size=4, stride=2, padding=0, bias=False), # 1x1
                # nn.Sigmoid()
            )    
    def layer(self, in_channels, out_channels, kernel_size, stride, padding, bn=True):
        if bn:
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(num_features=out_channels), # batch normalization
                nn.LeakyReLU(0.1, inplace=True)
            )
        return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
                nn.LeakyReLU(0.1, inplace=True)
            )
    def forward(self, input):
        return self.disc(input)
In [202]:
# Generator smoke test

N, in_channels, H, W = 8, 3, 256, 256 # batch of 8 images, 3x256x256 each
noise_dim = 100 # z dimension
z = torch.randn((N, noise_dim, 1, 1))

gen_test = Generator(noise_dim, in_channels, 8)
image = gen_test(z)[0]

image = image.permute(1, 2, 0) # CHW -> HWC for matplotlib
# Tanh outputs lie in [-1, 1]; rescale to [0, 255] before display
plt.imshow(((image.detach().numpy() * 0.5 + 0.5) * 255).astype(np.uint8))
plt.axis('off')
Out[202]:
(-0.5, 255.5, 255.5, -0.5)
In [203]:
def testGAN(resize=False):
    N, in_channels, H, W = 8, 3, 256, 256 # batch of 8 images, 3x256x256 each
    if (resize):
        H = 64
        W = 64
    noise_dim = 100 # z dimension
    x = torch.randn((N, in_channels, H, W))
    z = torch.randn((N, noise_dim, 1, 1))
    disc = Discriminator(in_channels, 8, resize=resize)
    gen = Generator(noise_dim, in_channels, 8, resize=resize)

    print(f'Discriminator output shape: {disc(x).shape}')
    assert disc(x).shape == (N, 1, 1, 1), "Discriminator failed the test" 

    print(f'Generator output shape: {gen(z).shape}')
    assert gen(z).shape == (N, in_channels, H, W), "Generator failed the test"
    print("Success!")

testGAN()
Discriminator output shape: torch.Size([8, 1, 1, 1])
Generator output shape: torch.Size([8, 3, 256, 256])
Success!
In [12]:
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
class LoadDataset(torch.utils.data.Dataset):
    def __init__(self, img_path, resize):
        super(LoadDataset, self).__init__()
        self.img_path = img_path
        self.resize = resize
        print(f'Path: {self.img_path}')
    
    def __len__(self):
        return len(os.listdir(self.img_path))
    
    def __getitem__(self, idx):
        pth = os.listdir(self.img_path)[idx]
        img = cv2.imread(self.img_path + pth, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # OpenCV loads BGR; convert to RGB
        img = torch.tensor(img)
        img = img.permute(2, 0, 1) # HWC -> CHW
        if self.resize: 
            img = torchvision.transforms.functional.resize(
                img, (64, 64), interpolation=T.InterpolationMode.BILINEAR)
        return img/255.0, 1
In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 0.0002
BETA = (0.5, 0.999)
BATCH_SIZE = 512 # 16, 256 - too time-consuming; 128 worked best
# IMAGE_SIZE = 256
IMG_CHANNELS = 3
Z_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64
NUM_EPOCHS = 1000

print(DEVICE)
cuda
In [13]:
# lzn_data_directory, vijesti_data_directory, pokemon_data_directory

dataset = LoadDataset(img_path=training_dataset, resize=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

for test_images, test_labels in dataloader:  
    sample_image = test_images[0]    
    sample_label = test_labels[0]
    # imgplot = plt.imshow(sample_image.permute(1, 2, 0))
    # plt.show()
    print(sample_image.shape, sample_label)
Path: ./datasets/simpsons/
C:\Users\User\anaconda3\lib\site-packages\torchvision\transforms\functional.py:442: UserWarning: Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
  warnings.warn(
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
torch.Size([3, 64, 64]) tensor(1)
In [207]:
gen = Generator(Z_DIM, IMG_CHANNELS, FEATURES_GEN, resize=True).to(DEVICE)
disc = Discriminator(IMG_CHANNELS, FEATURES_DISC, resize=True).to(DEVICE)

if (DEVICE.type == 'cuda'):
    netG = nn.DataParallel(gen, list(range(1)))
    netD = nn.DataParallel(disc, list(range(1)))

initialize_weights(gen)
initialize_weights(disc)

if (DEVICE.type == 'cuda'):
    netG.apply(initialize_weights)
    netD.apply(initialize_weights)

testGAN(resize=True)
Discriminator output shape: torch.Size([8, 1, 1, 1])
Generator output shape: torch.Size([8, 3, 64, 64])
Success!
In [208]:
opt_gen = optim.Adam(netG.parameters(), lr=LEARNING_RATE, betas=BETA)
opt_disc = optim.Adam(netD.parameters(), lr=LEARNING_RATE, betas=BETA)
criterion = nn.BCEWithLogitsLoss()
# criterion = nn.BCELoss()
fixed_noise = torch.randn((64, Z_DIM, 1, 1)).to(DEVICE)
In [209]:
gen.train();
disc.train();

Training the GAN model¶

In [224]:
print("TRAINING ON DATASET: ", training_dataset)

# Tips for stabilizing GAN training
# https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9

# Temporary lists for tracking losses and generated images
img_list = []
epoha_list = []
G_losses = []
D_losses = []

# Label smoothing - https://towardsdatascience.com/gan-ways-to-improve-gan-performance-acf37f9f59b
smoothing = 0.1
real_label = 1. - smoothing
fake_label = 0. + smoothing

# TODO: also add n_critic (Wasserstein GAN)

# Show images during training
show_images_while_training = True
save_image_iteration = 10

######## GAN training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(dataloader, 0):
        # Iterate over batches of images

        # --> Train on the batch of real images
        # Zero the gradients after every iteration, otherwise they accumulate on the leaves
        netD.zero_grad()
        
        # Unpack the real images (image tensor, batch size, smoothed labels)
        real_img = data[0].to(DEVICE)
        b_size = real_img.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=DEVICE)
        
        # --> The discriminator is trained on real images
        output = netD(real_img).view(-1)
        # Compute the loss of the discriminator's classification (it should output 1 here)
        errD_real = criterion(output, label)
        # Compute gradients for D via backward on the scalar loss tensor
        errD_real.backward()

        # --> Train on a batch of fake images
        # Sample a batch of latent vectors matching the real batch size
        noise = torch.randn(b_size, Z_DIM, 1, 1, device=DEVICE)
        # Generate fake images and set the label to fake
        fake = netG(noise)
        # Reuse the same label tensor, filled with the fake label
        label.fill_(fake_label)

        # --> Classify the fake images (detach removes them from the generator's graph)
        output = netD(fake.detach()).view(-1)
        # Compute the discriminator loss on fakes
        errD_fake = criterion(output, label)
        # Compute its gradients
        errD_fake.backward()
        # Total discriminator loss: real + fake
        errD = errD_real + errD_fake
        # Optimizer step for D
        opt_disc.step()

        # --> After updating the discriminator, update the Generator
        # Zero the generator's gradients after every iteration
        netG.zero_grad()
        label.fill_(real_label)  # the generator wants its fakes classified as real
        # Run the fakes through D again
        output = netD(fake).view(-1)
        # Generator loss based on D's classification
        errG = criterion(output, label)
        # Compute the gradient
        errG.backward()
        # Optimizer step for G
        opt_gen.step()

        G_losses.append(errG.item())
        D_losses.append(errD.item())    
        
    # Save the generator output using torchvision.utils
    if (epoch % save_image_iteration == 9) or ((epoch == NUM_EPOCHS-1) and (i == len(dataloader)-1)):
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
            # detach is used when no gradient is needed; PyTorch otherwise keeps all tensors in the graph
        img_list.append(torch.flip(vutils.make_grid(fake, padding=2, normalize=True), [-1]))
        epoha_list.append(epoch+1)
        # show images during training (useful on long runs, to spot stalls)
        if (show_images_while_training):
            image = torch.flip(vutils.make_grid(fake, padding=2, normalize=True).permute(1,2,0), [-1])
            plt.imshow(image)
            plt.show()
        
    print('Epoch [{:d}/{:d}] -> d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
            epoch+1, NUM_EPOCHS, errD.item(), errG.item()))
    
print("Training finished.")
TRAINING ON DATASET:  ./datasets/simpsons/
Epoch [1/1000] -> d_loss: 0.7437 | g_loss: 1.2477
Epoch [2/1000] -> d_loss: 0.7466 | g_loss: 2.7567
Epoch [3/1000] -> d_loss: 0.9323 | g_loss: 0.9265
Epoch [4/1000] -> d_loss: 0.7425 | g_loss: 1.4641
Epoch [5/1000] -> d_loss: 0.7188 | g_loss: 1.6329
Epoch [6/1000] -> d_loss: 1.0182 | g_loss: 3.3703
Epoch [7/1000] -> d_loss: 1.1227 | g_loss: 0.9412
Epoch [8/1000] -> d_loss: 0.7978 | g_loss: 2.2107
Epoch [9/1000] -> d_loss: 1.3784 | g_loss: 3.9970
Epoch [10/1000] -> d_loss: 0.7825 | g_loss: 1.3030
Epoch [11/1000] -> d_loss: 0.7777 | g_loss: 2.7425
Epoch [12/1000] -> d_loss: 0.6999 | g_loss: 2.1127
Epoch [13/1000] -> d_loss: 0.7047 | g_loss: 1.6068
Epoch [14/1000] -> d_loss: 1.0943 | g_loss: 3.7395
Epoch [15/1000] -> d_loss: 1.0073 | g_loss: 3.3448
Epoch [16/1000] -> d_loss: 0.7172 | g_loss: 2.0708
Epoch [17/1000] -> d_loss: 0.6821 | g_loss: 2.1056
Epoch [18/1000] -> d_loss: 0.8286 | g_loss: 2.0734
Epoch [19/1000] -> d_loss: 0.8047 | g_loss: 1.3893
Epoch [20/1000] -> d_loss: 0.7147 | g_loss: 1.8753
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_36428\1364346884.py in <module>
     21 ######## Treniranje GAN modela
     22 for epoch in range(NUM_EPOCHS):
---> 23     for i, data in enumerate(dataloader, 0):
     24         # Prolaz kroz grupe slika
     25 

~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627                 self._reset()  # type: ignore[call-arg]
--> 628             data = self._next_data()
    629             self._num_yielded += 1
    630             if self._dataset_kind == _DatasetKind.Iterable and \

~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    669     def _next_data(self):
    670         index = self._next_index()  # may raise StopIteration
--> 671         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672         if self._pin_memory:
    673             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

~\AppData\Local\Temp\ipykernel_36428\3590488943.py in __getitem__(self, idx)
     11 
     12     def __getitem__(self, idx):
---> 13         pth = os.listdir(self.img_path)[idx]
     14         img = cv2.imread(self.img_path + pth, cv2.IMREAD_COLOR)
     15         cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

KeyboardInterrupt: 

Plotting the losses¶

  • Training the generator and discriminator is a minimax game in which the two networks compete against each other. Unlike standard neural-network training, where the loss is expected to fall over time (until the model overfits), a GAN's losses are expected to oscillate throughout training, because neither the generator nor the discriminator should reach a loss near 0.
  • If the discriminator's loss were 0, it would separate real from generated images perfectly, meaning the generator never fools it. Conversely, the generator can only reach zero loss against a poorly tuned discriminator, so neither loss should collapse to 0.
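A useful reference point: a maximally unsure discriminator (output 0.5, i.e. logit 0) yields a loss of ln 2 ≈ 0.693 per label, so curves hovering around that value are expected rather than alarming. A quick check with the same BCEWithLogitsLoss used in training:

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
# A logit of 0 corresponds to sigmoid(0) = 0.5: the discriminator is maximally unsure.
logits = torch.zeros(8)
real_loss = criterion(logits, torch.ones(8)).item()   # target: real
fake_loss = criterion(logits, torch.zeros(8)).item()  # target: fake
print(real_loss, fake_loss)  # both ≈ ln 2 ≈ 0.6931
```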
In [225]:
# Plot the losses
plt.figure(figsize=(10,5))
plt.title("Training losses")
plt.plot(G_losses,label="Generator")
plt.plot(D_losses,label="Discriminator")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
In [214]:
# Animate the generated images across epochs
fig, ax = plt.subplots()
plt.axis("off")

container = []
generated_images = [[plt.imshow(torch.flip(np.transpose(i,(1,2,0)), [-1]), animated=True)] for i in img_list]


for i in range(len(epoha_list)):
    image_grid = generated_images[i][0]
    title = ax.text(0.5,1.05,"Epoch {}".format(epoha_list[i]), 
                    size=plt.rcParams["axes.titlesize"],
                    ha="center", transform=ax.transAxes, )
    container.append([image_grid, title])    


ani = animation.ArtistAnimation(fig, container, interval=700, repeat_delay=700, blit=False)

HTML(ani.to_jshtml())
Out[214]:

Saving the model and generating data with the Generator¶

  • The model is saved with the torch.save function, given a target path. The trained generator can then be loaded and used to generate images

  • The generator returns a tensor that represents an image. To display and save the image, the tensor has to be moved to the CPU and transposed
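The save/load round trip can be sketched with a tiny stand-in model (hypothetical file name; the same state_dict pattern applies to the trained Generator):

```python
import os
import tempfile
import torch
import torch.nn as nn

# Stand-in model: the pattern is identical for the trained Generator.
model = nn.Linear(4, 2)
path = os.path.join(tempfile.gettempdir(), 'generator.pth')  # hypothetical path
torch.save(model.state_dict(), path)  # persist only the parameters

# Restoring: build the same architecture, then load the saved parameters.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
restored.eval()  # switch to inference mode for generation
assert torch.equal(model.weight, restored.weight)
```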

In [223]:
# torch.save(gen.state_dict(), 'generator.pth')

N, in_channels, H, W = 8, 3, 64, 64 # batch of 8 images, 3x64x64 each
noise_dim = 100 # z dimension
z = torch.randn((N, noise_dim, 1, 1)).to(DEVICE) # the generator lives on DEVICE

generated_image = np.transpose(vutils.make_grid(netG(z).detach().cpu()),(1,2,0))
print(generated_image.shape)

# plt.imshow((fake.detach().cpu().numpy() * 255).astype(np.uint8))
# plt.axis('off')
torch.Size([68, 530, 3])

CGAN (Conditional GAN)¶

Hyperparameters: https://ijeee.edu.iq/Papers/Vol18-Issue1/1570796090.pdf

Adapting the GAN model to a CGAN: https://www.cs.toronto.edu/~lczhang/321/lec/gan_notes.html

  • This CGAN model uses Linear layers (y = x*AT + b), where x is the input, A the weight matrix and b the bias; the generator expands the dimensionality through successive Linear layers.

  • A => shape out_features x in_features (PyTorch stores the transposed weight, since y = x*AT + b)
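A small check of these shapes, using the same dimensions as the model below (z_dim = 100, 10 classes embedded into 10 features):

```python
import torch
import torch.nn as nn

z = torch.randn(32, 100)                 # latent batch
labels = torch.randint(0, 10, (32,))     # class labels
emb = nn.Embedding(10, 10)               # 10 classes -> 10-dim embedding
x = torch.cat([z, emb(labels)], dim=1)   # concatenated input: 100 + 10 = 110 features
fc = nn.Linear(110, 256)
# nn.Linear stores A as (out_features, in_features); y = x @ A.T + b
print(tuple(fc.weight.shape), tuple(fc(x).shape))  # (256, 110) (32, 256)
```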

In [92]:
from torchvision import datasets
from torchvision.transforms import ToTensor

training_dataset = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

data_loader = torch.utils.data.DataLoader(training_dataset, batch_size=32, shuffle=True)
In [93]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        # n_classes = 10, embedding dim = 10
        self.label_emb = nn.Embedding(10, 10)
        
        self.gen = nn.Sequential(
            nn.Linear(110, 256), # in_features = 110 (100 latent + 10 label), out_features = 256
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh() # normalization: maps outputs to [-1, 1]
        )
    
    def forward(self, z, labels):
        z = z.view(z.size(0), 100)
        c = self.label_emb(labels)
        # Concatenate the latent vector with the label embedding
        x = torch.cat([z, c], 1)
        out = self.gen(x)
        return out.view(x.size(0), 28, 28)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        # n_classes = 10, embedding dim = 10
        self.label_emb = nn.Embedding(10, 10)
        
        self.disc = nn.Sequential(
            nn.Linear(794, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3), # weaken the discriminator to curb overfitting
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid() # klasifikacija
        )
    
    def forward(self, x, labels):
        x = x.view(x.size(0), 784)
        c = self.label_emb(labels)
        x = torch.cat([x, c], 1)
        out = self.disc(x)
        return out.squeeze()
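The label conditioning in both networks hinges on the same `nn.Embedding` + `torch.cat` step; it can be checked in isolation (a standalone sketch with an illustrative batch of 4):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 10)            # 10 classes -> 10-dimensional label vectors
z = torch.randn(4, 100)               # batch of 4 noise vectors
labels = torch.tensor([0, 3, 7, 9])   # one class label per sample

c = emb(labels)                       # (4, 10)
x = torch.cat([z, c], dim=1)          # (4, 110): matches the Linear(110, 256) input
print(x.shape)  # torch.Size([4, 110])
```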
In [94]:
# https://www.researchgate.net/figure/The-architecture-of-conditional-GANCGAN_fig3_350115869

generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

criterion = nn.BCELoss()
# perhaps LR = 0.0002 would work better
opt_disc = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=BETA)
opt_gen = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=BETA)
In [95]:
print("TRAINING ON DATASET: ", training_dataset)

BATCH_SIZE = 32
NUM_EPOCHS = 50 # 100 works better
for epoch in range(NUM_EPOCHS):
    # Besides images, the loader also yields labels here
    for i, (images, labels) in enumerate(data_loader):
        
        real_images = Variable(images).to(DEVICE)
        labels = Variable(labels).to(DEVICE)
        generator.train()
        
        opt_disc.zero_grad()

        # Train on real images
        realD = discriminator(real_images, labels)
        errD_real = criterion(realD, Variable(torch.ones(BATCH_SIZE)).to(DEVICE))

        # Train on fake images with random labels
        z = Variable(torch.randn(BATCH_SIZE, 100)).to(DEVICE)
        fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, BATCH_SIZE))).to(DEVICE)
        fake = generator(z, fake_labels)
        fakeD = discriminator(fake, fake_labels)
        errD_fake = criterion(fakeD, Variable(torch.zeros(BATCH_SIZE)).to(DEVICE))

        # Compute the combined loss and update the discriminator
        errD = errD_real + errD_fake
        errD.backward()
        opt_disc.step()

        # Train the generator after the discriminator step
        opt_gen.zero_grad()
        z = Variable(torch.randn(BATCH_SIZE, 100)).to(DEVICE)
        fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, BATCH_SIZE))).to(DEVICE)
        fake_images = generator(z, fake_labels)
        output = discriminator(fake_images, fake_labels)
        errG = criterion(output, Variable(torch.ones(BATCH_SIZE)).to(DEVICE))
        errG.backward()
        opt_gen.step()

    print('Epoch [{:d}/{:d}] -> d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
        epoch+1, NUM_EPOCHS, errD.item(), errG.item()))

print("Training finished.")
TRAINING ON DATASET:  Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()
Epoch [1/50] -> d_loss: 1.2395 | g_loss: 0.8052
Epoch [2/50] -> d_loss: 1.3016 | g_loss: 0.8841
Epoch [3/50] -> d_loss: 1.2923 | g_loss: 0.9236
Epoch [4/50] -> d_loss: 1.2733 | g_loss: 0.8232
Epoch [5/50] -> d_loss: 1.3053 | g_loss: 0.8601
Epoch [6/50] -> d_loss: 1.1837 | g_loss: 0.8448
Epoch [7/50] -> d_loss: 1.2879 | g_loss: 1.0121
Epoch [8/50] -> d_loss: 1.2161 | g_loss: 0.7655
Epoch [9/50] -> d_loss: 1.2055 | g_loss: 1.0355
Epoch [10/50] -> d_loss: 1.3106 | g_loss: 0.8585
Epoch [11/50] -> d_loss: 1.1102 | g_loss: 0.9037
Epoch [12/50] -> d_loss: 1.2440 | g_loss: 0.9412
Epoch [13/50] -> d_loss: 1.1783 | g_loss: 1.1122
Epoch [14/50] -> d_loss: 1.1419 | g_loss: 1.0241
Epoch [15/50] -> d_loss: 1.2759 | g_loss: 0.9626
Epoch [16/50] -> d_loss: 1.1628 | g_loss: 0.9596
Epoch [17/50] -> d_loss: 1.2131 | g_loss: 0.8438
Epoch [18/50] -> d_loss: 1.3995 | g_loss: 0.9768
Epoch [19/50] -> d_loss: 1.1178 | g_loss: 1.0337
Epoch [20/50] -> d_loss: 1.1342 | g_loss: 0.9816
Epoch [21/50] -> d_loss: 1.1998 | g_loss: 0.7885
Epoch [22/50] -> d_loss: 1.2256 | g_loss: 0.7497
Epoch [23/50] -> d_loss: 1.1240 | g_loss: 1.0188
Epoch [24/50] -> d_loss: 1.1099 | g_loss: 0.9433
Epoch [25/50] -> d_loss: 1.1521 | g_loss: 1.0552
Epoch [26/50] -> d_loss: 1.0232 | g_loss: 1.1746
Epoch [27/50] -> d_loss: 1.1246 | g_loss: 0.9809
Epoch [28/50] -> d_loss: 1.1667 | g_loss: 1.1981
Epoch [29/50] -> d_loss: 1.1132 | g_loss: 1.4857
Epoch [30/50] -> d_loss: 1.1716 | g_loss: 1.5477
Epoch [31/50] -> d_loss: 1.0883 | g_loss: 1.1025
Epoch [32/50] -> d_loss: 1.1271 | g_loss: 1.0805
Epoch [33/50] -> d_loss: 1.1692 | g_loss: 1.2308
Epoch [34/50] -> d_loss: 1.3571 | g_loss: 1.2528
Epoch [35/50] -> d_loss: 0.7786 | g_loss: 1.3195
Epoch [36/50] -> d_loss: 0.9004 | g_loss: 1.3807
Epoch [37/50] -> d_loss: 1.0127 | g_loss: 1.0535
Epoch [38/50] -> d_loss: 1.1810 | g_loss: 1.0611
Epoch [39/50] -> d_loss: 0.8524 | g_loss: 1.0565
Epoch [40/50] -> d_loss: 0.8887 | g_loss: 1.0766
Epoch [41/50] -> d_loss: 0.8922 | g_loss: 1.5439
Epoch [42/50] -> d_loss: 0.9416 | g_loss: 1.3468
Epoch [43/50] -> d_loss: 0.8473 | g_loss: 1.8123
Epoch [44/50] -> d_loss: 0.8508 | g_loss: 1.5014
Epoch [45/50] -> d_loss: 0.9337 | g_loss: 1.2128
Epoch [46/50] -> d_loss: 1.0708 | g_loss: 0.9405
Epoch [47/50] -> d_loss: 1.1823 | g_loss: 1.1478
Epoch [48/50] -> d_loss: 1.0394 | g_loss: 1.4508
Epoch [49/50] -> d_loss: 0.9898 | g_loss: 1.3145
Epoch [50/50] -> d_loss: 1.0058 | g_loss: 1.4612
Training finished.
In [67]:
def generate_image(generator, img_class):
    # Sample one noise vector and the requested class label
    z = torch.randn(1, 100, device=DEVICE)
    label = torch.full((1,), img_class, dtype=torch.long, device=DEVICE)
    img = generator(z, label).data.cpu()
    img = 0.5 * img + 0.5  # undo the Tanh normalisation: [-1, 1] -> [0, 1]
    # CHW -> HWC for plt.imshow; permute keeps this a tensor so .numpy() works below
    return vutils.make_grid(img).permute(1, 2, 0)


for i in range(10):
    gen_img = generate_image(generator, i)

    plt.figure(figsize = (2,2))
    plt.axis('off')
    plt.imshow((gen_img.numpy() * 255).astype(np.uint8), aspect='auto')
In [ ]: